################################################################################
## Simulation for principal itr rules
################################################################################
library(tidyverse)
library(parallel)
select <- dplyr::select

#' Principal scores from a softmax model with a linear predictor
#'
#' Maps each row of `X` to a probability vector via a linear predictor
#' `X %*% beta` followed by a row-wise softmax.
#'
#' @param X Numeric design matrix (n x d).
#' @param beta Optional d x k coefficient matrix. When `NULL`, a d x 4 matrix
#'   of standard-normal draws is used.
#' @return An n x k matrix (k = 4 for the default `beta`) whose rows are
#'   probability vectors summing to 1.
linear_pscore_model <- function(X, beta = NULL) {
  d <- ncol(X)
  if (is.null(beta)) {
    beta <- matrix(rnorm(4 * d), nrow = d)
  }
  # natural parameters (linear predictor), one column per category
  nat_par <- X %*% beta

  # Softmax is invariant to subtracting a per-row constant. Shifting by the
  # row maximum keeps exp() from overflowing (Inf / Inf = NaN) when the
  # linear predictor is large, e.g. with the scaled coefficients drawn by
  # poly_pscore_model().
  nat_par <- nat_par - apply(nat_par, 1, max)
  expd <- exp(nat_par)

  # transform into probabilities
  expd / rowSums(expd)
}

#' Principal scores from a polynomial softmax model split by attribute
#'
#' Builds an intercept plus a poly() basis in the continuous covariates. It
#' interacts that basis with the last column of `X` (`a` and `1 - a`) and
#' passes the resulting design to linear_pscore_model().
#'
#' @param X Numeric matrix whose first `ncol(X) - 1` columns are continuous
#'   covariates and whose last column is the attribute (generate_pos() passes
#'   `cbind(X, a)`).
#' @param degree Polynomial degree passed to poly().
#' @param beta Optional coefficient matrix with one row per design column,
#'   i.e. `2 * choose(ncol(X) - 1 + degree, degree)` rows. When `NULL`,
#'   random coefficients are drawn. The same block is used for both attribute
#'   levels, and a zero last column is appended.
#' @return The principal-score matrix from linear_pscore_model().
poly_pscore_model <- function(X, degree, beta = NULL) {
  d <- ncol(X)
  if (is.null(beta)) {
    # poly() on the d - 1 continuous covariates yields
    # choose(d - 1 + degree, degree) - 1 non-constant terms; sizing the
    # default with that count makes it conform with the design below.
    n_terms <- choose(d - 1 + degree, degree) - 1
    beta <- -matrix(rnorm(3 * n_terms), ncol = 3) * 40
    intercept <- c(.2, .15, 0)
    # identical coefficient blocks for the `a` and `1 - a` halves
    beta <- rbind(intercept, beta, intercept, beta)
    # last column fixed at 0 (its natural parameter is always zero)
    beta <- cbind(beta, 0)
  }
  X_cont <- X[, 1:(d - 1)]
  polyx <- cbind(1, poly(X_cont, degree = degree))
  a <- X[, d]
  # row i of polyx is scaled by a[i] (column-wise recycling)
  design <- cbind(polyx * a, polyx * (1 - a))
  linear_pscore_model(design, beta)
}

#' Polynomial principal scores with mass shifted away from column 2
#'
#' Builds the same attribute-split polynomial design as poly_pscore_model().
#' It then subtracts `shft` from the second score column and adds
#' `shft / 3` to each other column, so every row keeps its total.
#'
#' @param X Numeric matrix: continuous covariates followed by the attribute
#'   as the last column.
#' @param degree Polynomial degree passed to poly().
#' @param shft Amount subtracted from column 2 of the scores.
#' @param beta Optional coefficient matrix with one row per design column,
#'   i.e. `2 * choose(ncol(X) - 1 + degree, degree)` rows, and 4 columns.
#'   When `NULL`, standard-normal coefficients of that size are drawn.
#' @return The shifted score matrix.
#'   NOTE(review): entries can become negative when `shft` exceeds a row's
#'   column-2 score. sample(prob = ...) in generate_pos() rejects negative
#'   probabilities, so confirm callers keep `shft` small enough.
poly_pscore_model_shift <- function(X, degree, shft, beta = NULL) {
  d <- ncol(X)
  X_cont <- X[, 1:(d - 1)]
  polyx <- cbind(1, poly(X_cont, degree = degree))
  a <- X[, d]
  design <- cbind(polyx * a, polyx * (1 - a))

  if (is.null(beta)) {
    # The default must have one row per column of the expanded design.
    # A d-row matrix is never conformable with it.
    p <- ncol(design)
    beta <- matrix(rnorm(4 * p), nrow = p)
  }
  pscores <- linear_pscore_model(design, beta)

  pscores[, 2] <- pscores[, 2] - shft
  pscores[, -2] <- pscores[, -2] + shft / 3
  pscores
}


#' Generate potential outcomes from a principal-score model
#'
#' Draws one stratum (1..4) per unit from the scores returned by `pmodel`.
#' It then looks up that stratum's pair of binary potential outcomes.
#'
#' @param X Numeric covariate matrix (n x d).
#' @param a Attribute values, appended to `X` as its last column before
#'   calling `pmodel`.
#' @param pmodel Function taking `cbind(X, a)` (plus `...`) and returning a
#'   matrix with one row of 4 stratum probabilities per unit.
#' @param ... Further arguments forwarded to `pmodel`.
#' @return A data.frame with columns `y0`, `y1`, `X1`..`Xd`, `A`, and the
#'   scores `e00`, `e10`, `e01`, `e11`.
generate_pos <- function(X, a, pmodel, ...) {
  # generate principal scores
  pscores <- pmodel(cbind(X, a), ...)

  # generate principal strata: one label in 1:4 per row of scores
  strata <- apply(pscores, 1, function(p) sample(1:4, 1, prob = p))

  # (y0, y1) for each stratum, one row per stratum. The score names below
  # follow e<y1><y0>, e.g. stratum 2 has y0 = 0, y1 = 1 and is named e10.
  po_mat <- matrix(c(0, 0, 0, 1, 1, 0, 1, 1), ncol = 2, byrow = TRUE)

  # drop = FALSE keeps an n x 2 matrix even when n == 1, so the cbind()
  # below does not truncate it to a single column
  po_dat <- po_mat[strata, , drop = FALSE]

  # overall data
  full_data <- data.frame(cbind(po_dat, X, a, pscores))
  names(full_data) <- c(
    "y0", "y1", paste0("X", seq_len(ncol(X))), "A",
    c("e00", "e10", "e01", "e11")
  )
  full_data
}


#' Sample data from a randomized trial (attribute fixed at 1)
#'
#' Draws covariates and generates potential outcomes with the attribute set
#' to 1 for every unit. Treatment is assigned by a fair coin, and the
#' derived score summaries are appended.
#'
#' @param n Number of units.
#' @param d Number of covariates.
#' @param pmodel Principal-score model passed on to generate_pos().
#' @param ... Further arguments forwarded to `pmodel`.
#' @return The generate_pos() data with added columns:
#'   - `trt` and observed outcome `y`
#'   - `m1 = e11 + e10`, `m0 = e11 + e01` and `tau = m1 - m0`
#'   - `tau_pos` (tau >= 0) and `pos_class` (m1 + m0 >= 1)
#'   - a constant `pscore` of 0.5
sample_rct <- function(n, d, pmodel, ...) {
  # covariates: independent normal draws with standard deviation 2
  covariates <- matrix(2 * rnorm(n * d), ncol = d)
  units <- generate_pos(covariates, rep(1, n), pmodel, ...)

  # fair-coin treatment assignment
  assignment <- sample(c(0, 1), n, TRUE, c(0.5, 0.5))

  units %>%
    mutate(
      trt = assignment,
      y = trt * y1 + (1 - trt) * y0,
      m1 = e11 + e10,
      m0 = e11 + e01,
      tau = m1 - m0,
      tau_pos = tau >= 0,
      pos_class = m1 + m0 >= 1,
      pscore = rep(0.5, n)
    )
}

#' Sample data from a randomized trial with a covariate-dependent attribute
#'
#' Draws covariates, then samples a binary attribute `a` with
#' P(a = 1) = 1 / (1 + exp(-X %*% a_beta)). It generates potential outcomes
#' given `a` and assigns treatment by a fair coin.
#'
#' @param n Number of units.
#' @param d Number of covariates.
#' @param pmodel Principal-score model passed on to generate_pos().
#' @param a_beta Coefficients (length `d`) of the logistic model for `a`.
#' @param ... Further arguments forwarded to `pmodel`.
#' @return The generate_pos() data plus `trt` and observed outcome `y`.
sample_rct_with_attribute <- function(n, d, pmodel, a_beta, ...) {

  # sample covariates
  X <- matrix(2 * rnorm(n * d), ncol = d)

  # sample a protected attribute correlated with covariates
  prob_a <- 1 / (1 + exp(-X %*% a_beta))
  # seq_len() avoids iterating over i = 1, 0 when n == 0, and vapply()
  # guarantees a numeric vector; one sample() draw per unit
  a <- vapply(
    seq_len(n),
    function(i) sample(c(0, 1), 1, prob = c(1 - prob_a[i], prob_a[i])),
    numeric(1)
  )

  # get potential outcomes
  data <- generate_pos(X, a, pmodel, ...)

  # get treatment and observed outcomes
  trt <- sample(c(0, 1), n, TRUE, c(0.5, 0.5))

  data$trt <- trt
  data$y <- trt * data$y1 + (1 - trt) * data$y0

  data
}



#' Run many simulations in parallel
#'
#' Calls `single_run(j, ...)` for j in 1..n_sims via parallel::mclapply().
#' It keeps only results that carry names and row-binds them. Dropped
#' results are reported with a warning. Dropped results include NULL, or the
#' "try-error" objects mclapply() returns for runs that failed in a forked
#' worker.
#'
#' @param n_sims Number of simulations to run
#' @param n_cores Number of cores
#' @param ... Simulation arguments, forwarded to `single_run()`
#' @return The row-bound kept results (via bind_rows()).
run_sim <- function(n_sims, n_cores = 1, ...) {
  cat("Number of simulations", n_sims, "\n")

  # seq_len() so that n_sims = 0 runs nothing (1:0 would run j = 1 and 0)
  out <- mclapply(
    seq_len(n_sims),
    function(j) single_run(j, ...),
    mc.cores = n_cores
  )

  keep <- vapply(out, function(x) !is.null(names(x)), logical(1))
  if (any(!keep)) {
    warning(
      sum(!keep), " of ", length(out),
      " simulation runs returned no named result and were dropped",
      call. = FALSE
    )
  }
  cat("\n\n")
  bind_rows(out[keep])
}


